#!/usr/bin/env python
# coding: utf-8

# In[2]:

import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
import pandas as pd
from pathos.multiprocessing import cpu_count, Pool
import emcee
from functools import partial

from operator import add
from functools import reduce
from contextlib import closing

# Default figure styling for every plot produced by this script.
rcParams.update({
    'figure.figsize': (13.0, 7.0),
    'lines.linewidth': 3,
    'font.size': 18,
})

from jm.library.data_helper import filter_data, is_between
from jm.library.likelihood import Likelihood
from jm.library.simulator import TargetPrioritiesTargetActions, CounterfactualConfigs
from jm.library.additional_helpers import plot_over_time, payment_cols, priority_cols, action_cols, \
    relative_payment_cols, fwdiff_payment_cols, relative_fwdiff_payment_cols


def is_ever(df, cols, value, columns=None):
    """Flag, per row and per period, whether ``value`` has appeared so far.

    For the i-th entry of ``cols`` the result is 1 if ``value`` occurs in
    any of ``cols[:i+1]`` for that row and 0 otherwise, i.e. an
    "ever reached by this period" indicator.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing ``cols``.
    cols : sequence of str
        Column names, ordered in time.
    value : object
        Value to look for; cells are compared with ``==`` (NaN/None never match).
    columns : sequence, optional
        Labels for the output columns. Defaults to ``Likelihood.event_dates``;
        must have the same length as ``cols``.

    Returns
    -------
    pandas.DataFrame
        0/1 ``int64`` frame with ``df``'s index and one column per entry of ``cols``.
    """
    if columns is None:
        columns = Likelihood.event_dates
    # A running maximum along the column (time) axis turns "equals value in
    # column i" into "equalled value in some column <= i". This replaces a
    # per-prefix row-wise apply in a process pool, which had to stash its
    # inputs in module globals to reach the workers.
    hits = df.loc[:, list(cols)].eq(value).astype('int64')
    result = hits.cummax(axis=1)
    result.columns = columns
    return result


# In[6]:


# Re-apply the default figure styling (same values as at the top of the file).
rcParams.update({
    'figure.figsize': (13.0, 7.0),
    'lines.linewidth': 3,
    'font.size': 18,
})


# In[7]:

# Load the de-identified status data used by every later cell of this script.
df_status = pd.read_csv('data/jesus_maria_status_deidentified.csv')


# In[9]:

# Express payments as a share of each row's total amount due: first the
# `payment_cols`, then the `fwdiff_payment_cols`, each divided row-wise.
df_status[relative_payment_cols] = df_status[payment_cols].div(
    df_status['total_due'], axis=0)
df_status[relative_fwdiff_payment_cols] = df_status[fwdiff_payment_cols].div(
    df_status['total_due'], axis=0)


# In[61]:


list_tax_due = [99, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1250, 1500, 2000, 60000]


def _positive_payments_by_bracket(df, brackets):
    """Collect the strictly positive relative per-period payments per bracket.

    Consecutive entries of ``brackets`` give the bracket edges; rows are
    selected with ``filter_data(df, total_due=is_between(lower, upper))``.

    Returns
    -------
    dict
        Maps each bracket's lower edge to a flat ndarray of the values > 0
        found in ``relative_fwdiff_payment_cols`` for the selected rows.
    """
    by_bracket = {}
    for lower, upper in zip(brackets[:-1], brackets[1:]):
        in_bracket = filter_data(df, total_due=is_between(lower, upper))
        values = in_bracket[relative_fwdiff_payment_cols].to_numpy().flatten()
        by_bracket[lower] = values[values > 0]
    return by_bracket


# Pooled empirical payment distribution per bracket (passed to the simulators).
map_payments = _positive_payments_by_bracket(df_status, list_tax_due)

df_treatment = filter_data(df_status, assignment_to_treatment=1)
df_control = filter_data(df_status, assignment_to_treatment=0)

# The same distributions, split by experimental arm.
map_payments_treatment = _positive_payments_by_bracket(df_treatment, list_tax_due)
map_payments_control = _positive_payments_by_bracket(df_control, list_tax_due)


# In[16]:


# Sum of 'payments_by_2021-09-06' for each value of assignment_to_treatment.
payments = (
    df_status
    .groupby('assignment_to_treatment')['payments_by_2021-09-06']
    .sum()
)


# In[17]:


# Model parameter names, in the order of the last axis of the MCMC chain
# (used as the column labels of the flattened posterior samples).
PARAM_NAMES = (
    'some_payment',
    'payment',
    'priority_G1',
    'priority_G2',
    'priority_G3',
    'action_valor',
    'action_rec1',
    'action_medida',
    'covariate',
    'G1_rec1',
    'G1_medida',
    's_shape_low',
    's_shape_high',
    'type_sdev',
)


# In[18]:


# Load the MCMC draws and summarize the posterior by its mean.
backend = emcee.backends.HDFBackend(
    'estimation/parameter_estimates/jesus_maria_mcmc_G1actioninteraction.h5')
state = backend.get_last_sample()

_test_mode = os.environ.get("test_size") == "true"

# Burn-in steps to drop (none in the quick test configuration) and thinning.
ndiscard = 0 if _test_mode else 3000
nthin = 1

# emcee returns the chain with axes (step, walker, parameter).
chain_array = backend.get_chain(flat=False)
chain_array = chain_array[ndiscard::nthin, :, :]
# Merge steps and walkers into one sample axis. Letting numpy infer the row
# count (-1) avoids hard-coding the walker count (previously 128) and the
# floor-vs-ceil mismatch when nthin does not divide the remaining steps.
chain_array_flat = chain_array.reshape(-1, chain_array.shape[2])
flat_len = chain_array_flat.shape[0]

df_params = pd.DataFrame(chain_array_flat, columns=PARAM_NAMES)
params = df_params.mean().to_numpy()

# Last parameter ('type_sdev'): scale of the per-row random effect 'theta'
# drawn in run_simulator.
sigma_theta = params[-1]

# Simulation settings.
n = len(df_treatment)  # rows in the simulated population
num_sim = 1 if os.environ.get("test_size") == "true" else 400  # seeded replications

configs = CounterfactualConfigs()
initial_G1 = 350  # leading rows that start with priority "G1" in run_simulator


# In[20]:


# Order treatment rows by score, highest first (stable sort); run_simulator
# gives priority "G1" to the first `initial_G1` rows of this ordering.
df_treatment = df_treatment.sort_values(
    by='effective_score', ascending=False, kind='mergesort')

# Columns of the initial simulator state. run_simulator reads positions 1-3
# as the first payment, priority and action columns.
state_0_selected_cols = [
    'total_due',
    payment_cols[0],
    priority_cols[0],
    action_cols[0],
    'prob_repayment_endo_covariates',
]


# In[21]:


# Flat list of simulator output column names: for every event date, the
# payments, priority and action columns in that order. This matches the
# per-step state slices that run_simulator concatenates.
# (A nested comprehension replaces reduce(add, ...), which copies the
# growing list on every step.)
cols = [
    '{}_by_{}'.format(kind, day.strftime('%Y-%m-%d'))
    for day in Likelihood.event_dates
    for kind in ('payments', 'priority', 'action')
]


# In[22]:


def run_simulator(sim_seed, initial_G1=initial_G1, use_df_control=False):
    """Run one seeded simulation and return the simulated panel.

    Parameters
    ----------
    sim_seed : tuple
        ``(sim, seed)``: a simulator exposing ``reset()`` and ``sim(state, d)``,
        plus the seed applied to NumPy's global RNG.
    initial_G1 : int
        Number of leading rows, in the source frame's current order, that
        start with priority "G1"; every other row starts with "G3".
    use_df_control : bool
        Start from the control-arm rows (``df_control``) instead of the
        treatment-arm rows (``df_treatment``).

    Returns
    -------
    pandas.DataFrame
        One row per simulated unit; columns are ``cols`` (payment, priority
        and action for each period, starting with the initial state).
    """
    sim, seed = sim_seed
    sim.reset()
    np.random.seed(seed)

    source = df_control if use_df_control else df_treatment
    # Explicit copy so the assignments below never touch (or warn about) the
    # shared source frame.
    s0 = source.loc[:, state_0_selected_cols].copy()
    s0[priority_cols[0]] = "G3"
    # Positional selection: exactly the first `initial_G1` rows, even if the
    # index contains duplicate labels.
    s0.iloc[:initial_G1, s0.columns.get_loc(priority_cols[0])] = "G1"
    # Per-row random effect, drawn after seeding so each run is reproducible.
    s0['theta'] = sigma_theta * np.random.randn(len(s0))

    initial_state = s0.to_numpy()
    # State columns 1-3 are the payment, priority and action columns.
    list_states = [initial_state[:, [1, 2, 3]]]
    # `cols` holds three columns per period, one period being the initial
    # state, so the remaining periods are simulated steps.
    n_steps = len(cols) // 3 - 1
    s = np.array(initial_state)
    for d in range(n_steps):
        s = np.array(sim(s, d))
        list_states.append(s[:, [1, 2, 3]])
    sim_res = np.concatenate(list_states, axis=1)
    return pd.DataFrame(sim_res, columns=cols)


# In[23]:


def seeded_simulator(sim, base_seed, num_sim, use_df_control=False):
    """Pair ``sim`` with ``num_sim`` consecutive seeds starting at ``base_seed``.

    Returns ``[(sim, base_seed), (sim, base_seed + 1), ...]``, the argument
    list for ``Pool.map(run_simulator, ...)``.

    ``use_df_control`` is accepted for backward compatibility but has no
    effect. The previous version zipped the seeds with
    ``len(population) * [sim]``, whose only effect was to silently cap the
    number of runs at the population size.
    """
    return [(sim, seed) for seed in range(base_seed, base_seed + num_sim)]


# In[24]:


# 'base sim - as implemented'

base_config = configs.longer_G1_more_G1_actual

np.random.seed(1234)
base_simulator = TargetPrioritiesTargetActions(
    params, map_payments, n=n, config=base_config, list_tax_due=list_tax_due,
    timetrend=False, enriched_payments=False)

# Leave two cores free but always request at least one worker: a pool size
# of 0 or less (cpu_count() <= 2) is rejected by the Pool constructor.
n_workers = max(1, cpu_count() - 2)
with closing(Pool(n_workers)) as pool:
    base_res = pool.map(run_simulator, seeded_simulator(base_simulator, 1234, num_sim))

# Column sums of the payment columns per run: one row per run, one column per event date.
payments_base_res = [df[payment_cols].sum(axis=0) for df in base_res]
base_res_inrec1 = [df[action_cols] == 'rec1' for df in base_res]
df_base_res = pd.DataFrame(payments_base_res)
df_base_res.columns = Likelihood.event_dates


# 'control sim'
# Re-sort the treatment rows by amount due; run_simulator reads df_treatment
# in this order (with initial_G1=0 no row is seeded with priority "G1").
df_treatment = df_treatment.sort_values(
    'total_due', ascending=False, kind='mergesort')
np.random.seed(1234)

# NOTE(review): use_df_control keeps its default (False), so the control
# scenario is simulated on the treatment-arm rows. Confirm this is intended
# when the result is compared against the actual control arm.
run_simulator_control = partial(run_simulator, initial_G1=0)

control_config = configs.control_config
params_control = params.copy()
# Zero the priority effects (PARAM_NAMES[2:5]: priority_G1..priority_G3) and
# the G1 interaction terms (PARAM_NAMES[9:11]: G1_rec1, G1_medida).
params_control[2:5] = 3*[0]
params_control[9:11] = 2*[0]

control_simulator = TargetPrioritiesTargetActions(
    params_control, map_payments, n=n, config=control_config, list_tax_due=list_tax_due,
    timetrend=False, enriched_payments=False)

np.random.seed(1234)
# Leave two cores free but always request at least one worker.
n_workers = max(1, cpu_count() - 2)
with closing(Pool(n_workers)) as pool:
    control_res = pool.map(run_simulator_control, seeded_simulator(control_simulator, 1234, num_sim))

# Column sums of the payment columns per run: one row per run, one column per event date.
payments_control_res = [df[payment_cols].sum(axis=0) for df in control_res]
df_control_res = pd.DataFrame(payments_control_res)
df_control_res.columns = Likelihood.event_dates


# # Comparison Base Simulation to Actual Treatment

# Per-period counts over the simulated runs. Each *_t list holds one boolean
# mask per run; the matching DataFrame has one row per run and one column per
# event date giving how many rows satisfy the condition in that period.
base_res_inG1_t = [res[priority_cols] == 'G1' for res in base_res]
base_res_inG1 = pd.DataFrame([mask.sum(axis=0) for mask in base_res_inG1_t])
base_res_inG1.columns = Likelihood.event_dates

base_res_inmedida_t = [res[action_cols] == 'medida' for res in base_res]
base_res_inmedida = pd.DataFrame([mask.sum(axis=0) for mask in base_res_inmedida_t])
base_res_inmedida.columns = Likelihood.event_dates

# Counts rows whose action is either 'rec1' or 'medida'.
base_res_everrec1_t = [
    (res[action_cols] == 'rec1') | (res[action_cols] == 'medida') for res in base_res
]
base_res_everrec1 = pd.DataFrame([mask.sum(axis=0) for mask in base_res_everrec1_t])
base_res_everrec1.columns = Likelihood.event_dates

# Counts rows with any action other than 'N'.
base_res_evervalor_t = [res[action_cols] != 'N' for res in base_res]
base_res_evervalor = pd.DataFrame([mask.sum(axis=0) for mask in base_res_evervalor_t])
base_res_evervalor.columns = Likelihood.event_dates

plt.clf()
fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3)

# Panel 1: cumulative taxes collected, actual vs mean simulated (millions).
plot_over_time(df_treatment, ax=ax1, normalize=1e6)
df_base_res_norm = df_base_res / 1e6
df_base_res_norm.mean().plot(ax=ax1)
ax1.set_title('cumulative taxes collected')
ax1.set_ylabel('M soles')
ax1.legend(['actual', 'sim'], loc="upper left", fontsize=13)

# Panel 2: mean payment relative to amount due. The base simulation ran on
# df_treatment ordered by 'effective_score', so that ordering is re-applied
# to line each simulated row up with its total_due (computed once).
total_due_by_score = df_treatment.sort_values(
    'effective_score', ascending=False, kind='mergesort').reset_index()['total_due']
plot_over_time(df_treatment, col_fmt='relative_payments_by_{}', func=np.mean, ax=ax2)
plot_over_time(
    pd.DataFrame([
        res[payment_cols].div(total_due_by_score, axis=0).mean()
        for res in base_res
    ]),
    col_fmt='payments_by_{}', func=np.mean, ax=ax2)
ax2.set_title('mean relative payment')
ax2.legend(['actual', 'sim'], loc="upper left", fontsize=13)

# Panels 3-6: actual "ever reached" counts (is_ever) vs mean simulated counts.
panel_specs = (
    (ax3, priority_cols, 'G1', base_res_inG1, 'number with G1 priority'),
    (ax4, action_cols, 'valor', base_res_evervalor, 'number with notifications'),
    (ax5, action_cols, 'medida', base_res_inmedida, 'number of garnishments'),
    (ax6, action_cols, 'rec1', base_res_everrec1, 'number of writs'),
)
for panel_ax, actual_cols, actual_value, sim_counts, panel_title in panel_specs:
    is_ever(df_treatment, actual_cols, actual_value).sum(axis=0).plot(ax=panel_ax)
    sim_counts.mean().plot(ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.legend(['actual', 'sim'], loc="upper left", fontsize=13)

fig.tight_layout()

plt.savefig('figs/figOB1.pdf')

# # Comparison Control Simulation to Actual Control

# Per-period counts over the simulated control runs: one row per run, one
# column per event date (number of rows satisfying each condition).
control_res_inmedida_t = [res[action_cols] == 'medida' for res in control_res]
control_res_inmedida = pd.DataFrame([mask.sum(axis=0) for mask in control_res_inmedida_t])
control_res_inmedida.columns = Likelihood.event_dates

# Counts rows whose action is either 'rec1' or 'medida'.
control_res_everrec1_t = [
    (res[action_cols] == 'rec1') | (res[action_cols] == 'medida') for res in control_res
]
control_res_everrec1 = pd.DataFrame([mask.sum(axis=0) for mask in control_res_everrec1_t])
control_res_everrec1.columns = Likelihood.event_dates

# Counts rows with any action other than 'N'.
control_res_evervalor_t = [res[action_cols] != 'N' for res in control_res]
control_res_evervalor = pd.DataFrame([mask.sum(axis=0) for mask in control_res_evervalor_t])
control_res_evervalor.columns = Likelihood.event_dates


plt.clf()
fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3)
# The control comparison has no G1-priority panel.
fig.delaxes(ax3)

# Panel 1: cumulative taxes collected, actual vs mean simulated (millions).
# NOTE(review): the actual series covers df_control while the control
# simulation was run on the df_treatment rows; totals are only comparable if
# the two arms are of similar size — confirm.
plot_over_time(df_control, ax=ax1, normalize=1e6)
df_control_res_norm = df_control_res / 1e6
df_control_res_norm.mean().plot(ax=ax1)
ax1.set_title('cumulative taxes collected')
ax1.set_ylabel('M soles')
ax1.legend(['actual', 'sim'], loc="upper left", fontsize=13)

# Panel 2: mean payment relative to amount due. The control simulation ran on
# df_treatment ordered by 'total_due', so each simulated row is divided by
# the total_due at the same position of that ordering (computed once).
control_total_due = df_treatment.sort_values(
    'total_due', ascending=False, kind='mergesort').reset_index()['total_due']
plot_over_time(df_control, col_fmt='relative_payments_by_{}', func=np.mean, ax=ax2)
plot_over_time(
    # Iterate the control runs themselves (the old code ranged over
    # len(base_res), which only worked while both lists had equal length).
    pd.DataFrame([
        res[payment_cols].div(control_total_due, axis=0).mean()
        for res in control_res
    ]),
    col_fmt='payments_by_{}', func=np.mean, ax=ax2)
ax2.set_title('mean relative payment')
ax2.legend(['actual', 'sim'], loc="upper left", fontsize=13)

# Panels 4-6: actual "ever reached" counts (is_ever) vs mean simulated counts.
control_panel_specs = (
    (ax4, 'valor', control_res_evervalor, 'number with notifications'),
    (ax5, 'medida', control_res_inmedida, 'number of garnishments'),
    (ax6, 'rec1', control_res_everrec1, 'number of writs'),
)
for panel_ax, actual_value, sim_counts, panel_title in control_panel_specs:
    is_ever(df_control, action_cols, actual_value).sum(axis=0).plot(ax=panel_ax)
    sim_counts.mean().plot(ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.legend(['actual', 'sim'], loc="upper left", fontsize=13)

fig.tight_layout()

plt.savefig('figs/figOB2.pdf')
